{ "cells": [ { "cell_type": "markdown", "id": "fb2cc20e", "metadata": {}, "source": [ "# Tutorial 11 - Optimizer Utilities" ] }, { "cell_type": "markdown", "id": "fa2a995e", "metadata": {}, "source": [ "In this tutorial we explore the optimizer utility functions provided in jaxKAN. We cover two optimizers: Adam, combined with various learning rate schedules, and L-BFGS, which requires special handling in Flax NNX. Both are particularly useful for training KANs and PIKANs, where the choice of optimizer and learning rate schedule can significantly affect convergence." ] }, { "cell_type": "code", "execution_count": 1, "id": "c8ce746d", "metadata": {}, "outputs": [], "source": [ "import os\n", "# Set before importing JAX so the logging level is applied from the start\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"2\"\n", "\n", "from jaxkan.models.KAN import KAN\n", "from jaxkan.models.utils import get_adam, get_lbfgs\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "from sklearn.model_selection import train_test_split\n", "\n", "from flax import nnx\n", "import optax\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": null, "id": "6cf8dbf1", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "f02dfce6", "metadata": {}, "source": [ "## Data Generation" ] }, { "cell_type": "markdown", "id": "f560e4c7", "metadata": {}, "source": [ "We will use the same function fitting problem from Tutorial 1 to compare different optimizer configurations. Consider the function $f(x, y) = x^2 + 2\\exp(y)$, which we will fit using a KAN model with different optimization strategies." 
] }, { "cell_type": "code", "execution_count": 2, "id": "d12fd8f5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training set size: (800, 2)\n", "Test set size: (200, 2)\n" ] } ], "source": [ "def f(x, y):\n", " return x**2 + 2*jnp.exp(y)\n", "\n", "def generate_data(minval=-1, maxval=1, num_samples=1000, seed=42):\n", " key = jax.random.PRNGKey(seed)\n", " x_key, y_key = jax.random.split(key)\n", "\n", " x1 = jax.random.uniform(x_key, shape=(num_samples,), minval=minval, maxval=maxval)\n", " x2 = jax.random.uniform(y_key, shape=(num_samples,), minval=minval, maxval=maxval)\n", "\n", " y = f(x1, x2).reshape(-1, 1)\n", " X = jnp.stack([x1, x2], axis=1)\n", " \n", " return X, y\n", "\n", "seed = 42\n", "\n", "X, y = generate_data(minval=-1, maxval=1, num_samples=1000, seed=seed)\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed)\n", "\n", "print(\"Training set size:\", X_train.shape)\n", "print(\"Test set size:\", X_test.shape)" ] }, { "cell_type": "code", "execution_count": null, "id": "86e44b0a", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "0e1983a1", "metadata": {}, "source": [ "## Part 1: Adam Optimizer" ] }, { "cell_type": "markdown", "id": "0ce104aa", "metadata": {}, "source": [ "The Adam optimizer is the most commonly used optimizer for training neural networks, including KANs. The `get_adam` function provides a convenient interface for creating Adam optimizers with various learning rate schedules and warmup strategies." ] }, { "cell_type": "markdown", "id": "d90dbfe6", "metadata": {}, "source": [ "### Experiment Setup for Adam" ] }, { "cell_type": "markdown", "id": "61d04870", "metadata": {}, "source": [ "We define a function to encapsulate the training loop. This function will allow us to easily experiment with different Adam configurations." 
] }, { "cell_type": "code", "execution_count": 3, "id": "518a4234", "metadata": {}, "outputs": [], "source": [ "def run_adam_experiment(adam_config, num_epochs=2000, verbose=True):\n", "    \"\"\"\n", "    Run a training experiment with Adam optimizer.\n", "\n", "    Args:\n", "        adam_config: Dictionary with Adam optimizer parameters\n", "        num_epochs: Number of training epochs\n", "        verbose: Whether to print progress\n", "\n", "    Returns:\n", "        train_losses: Array of training losses\n", "        test_loss: Final test loss\n", "    \"\"\"\n", "    # Initialize a KAN model\n", "    n_in = X_train.shape[1]\n", "    n_out = y_train.shape[1]\n", "    n_hidden = 6\n", "\n", "    layer_dims = [n_in, n_hidden, n_hidden, n_out]\n", "    req_params = {'D': 5, 'flavor': 'exact'}\n", "\n", "    model = KAN(layer_dims=layer_dims,\n", "                layer_type='chebyshev',\n", "                required_parameters=req_params,\n", "                seed=42)\n", "\n", "    # Get Adam optimizer using the utility function\n", "    opt_type = get_adam(**adam_config)\n", "    optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)\n", "\n", "    # Define train step\n", "    @nnx.jit\n", "    def train_step(model, optimizer, X_train, y_train):\n", "        def loss_fn(model):\n", "            residual = model(X_train) - y_train\n", "            loss = jnp.mean((residual)**2)\n", "            return loss\n", "\n", "        loss, grads = nnx.value_and_grad(loss_fn)(model)\n", "        optimizer.update(model, grads)\n", "\n", "        return loss\n", "\n", "    # Training loop\n", "    train_losses = jnp.zeros((num_epochs,))\n", "\n", "    for epoch in range(num_epochs):\n", "        loss = train_step(model, optimizer, X_train, y_train)\n", "        train_losses = train_losses.at[epoch].set(loss)\n", "\n", "        if verbose and (epoch + 1) % 500 == 0:\n", "            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.6f}')\n", "\n", "    # Evaluate on test set\n", "    y_pred = model(X_test)\n", "    test_loss = jnp.mean((y_pred - y_test)**2)\n", "\n", "    if verbose:\n", "        print(f'\\nFinal Test Loss: {test_loss:.6f}')\n", "\n", "    return train_losses, test_loss" ] }, { "cell_type": "code", 
"execution_count": null, "id": "5dd48c1e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "5cbc2aee", "metadata": {}, "source": [ "### Constant Learning Rate" ] }, { "cell_type": "markdown", "id": "60f884fa", "metadata": {}, "source": [ "We begin with the simplest case: a constant learning rate. This is the default behavior when no schedule is specified." ] }, { "cell_type": "code", "execution_count": 4, "id": "bf004788", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch [500/2000], Loss: 0.064715\n", "Epoch [1000/2000], Loss: 0.016274\n", "Epoch [1500/2000], Loss: 0.007237\n", "Epoch [2000/2000], Loss: 0.003581\n", "\n", "Final Test Loss: 0.005624\n" ] } ], "source": [ "config_constant = {\n", " 'learning_rate': 1e-3\n", "}\n", "\n", "train_losses_constant, test_loss_constant = run_adam_experiment(config_constant)" ] }, { "cell_type": "code", "execution_count": null, "id": "b41fdecd", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "d800ac0a", "metadata": {}, "source": [ "### Exponential Decay Schedule" ] }, { "cell_type": "markdown", "id": "7f579d87", "metadata": {}, "source": [ "A common strategy is to use exponential decay, where the learning rate decreases exponentially over time. This can help the model converge to a better solution by taking smaller steps as training progresses.\n", "\n", "The learning rate at step $t$ is given by:\n", "\n", "$$\\text{lr}(t) = \\text{lr}_0 \\cdot \\gamma^{t / T}$$\n", "\n", "where $\\text{lr}_0$ is the initial learning rate, $\\gamma$ is the decay rate, and $T$ is the number of decay steps." 
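, "\n", "\n", "The cell below is a minimal sketch that evaluates this formula alongside `optax.exponential_decay`, using the values from the next experiment ($\\text{lr}_0 = 10^{-3}$, $\\gamma = 0.9$, $T = 1000$). It assumes, without confirmation from the jaxKAN source, that `schedule_type='exponential'` in `get_adam` behaves like `optax.exponential_decay` with `staircase=False`." ] }, { "cell_type": "code", "execution_count": null, "id": "sketch-exp-decay", "metadata": {}, "outputs": [], "source": [ "import optax\n", "\n", "# Sketch: lr(t) = lr_0 * gamma**(t / T), evaluated by hand and with Optax.\n", "# Assumption: get_adam's 'exponential' schedule behaves like optax.exponential_decay\n", "# with staircase=False; this is not confirmed by the jaxKAN source.\n", "lr0, gamma, T = 1e-3, 0.9, 1000\n", "exp_schedule = optax.exponential_decay(init_value=lr0, transition_steps=T, decay_rate=gamma)\n", "\n", "for t in [0, 500, 1000, 2000]:\n", "    by_hand = lr0 * gamma ** (t / T)\n", "    print(f'step {t:4d}: optax = {float(exp_schedule(t)):.4e}, formula = {by_hand:.4e}')" 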
] }, { "cell_type": "code", "execution_count": 5, "id": "75090d90", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch [500/2000], Loss: 0.068539\n", "Epoch [1000/2000], Loss: 0.017559\n", "Epoch [1500/2000], Loss: 0.008215\n", "Epoch [2000/2000], Loss: 0.004304\n", "\n", "Final Test Loss: 0.006682\n" ] } ], "source": [ "config_exp_decay = {\n", " 'learning_rate': 1e-3,\n", " 'schedule_type': 'exponential',\n", " 'decay_steps': 1000,\n", " 'decay_rate': 0.9\n", "}\n", "\n", "train_losses_exp, test_loss_exp = run_adam_experiment(config_exp_decay)" ] }, { "cell_type": "code", "execution_count": null, "id": "90250d45", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "23133d54", "metadata": {}, "source": [ "### Cosine Annealing Schedule" ] }, { "cell_type": "markdown", "id": "3e9e21c1", "metadata": {}, "source": [ "Cosine annealing decays the learning rate smoothly along a cosine curve. Starting from the initial learning rate, it decreases slowly at first, fastest around the midpoint of the schedule, and slowly again as it approaches zero (or a minimum value). Depending on the problem, this can lead to better convergence. In the configuration below, `decay_steps` is set to the total number of epochs (2000), so the schedule spans the whole training run." 
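, "\n", "\n", "As a minimal sketch, the cell below compares the cosine-annealing formula with `optax.cosine_decay_schedule` for $\\text{lr}_0 = 10^{-3}$ and $T = 2000$ decay steps. That `get_adam` builds its cosine schedule this way, decaying all the way to zero, is an assumption rather than something confirmed by the jaxKAN source." ] }, { "cell_type": "code", "execution_count": null, "id": "sketch-cosine", "metadata": {}, "outputs": [], "source": [ "import math\n", "import optax\n", "\n", "# Sketch: cosine annealing from lr_0 down to alpha * lr_0 over T steps.\n", "# Assumption: get_adam's 'cosine' schedule behaves like optax.cosine_decay_schedule\n", "# with alpha=0.0; this is not confirmed by the jaxKAN source.\n", "lr0, T, alpha = 1e-3, 2000, 0.0\n", "cos_schedule = optax.cosine_decay_schedule(init_value=lr0, decay_steps=T, alpha=alpha)\n", "\n", "def cosine_by_hand(t):\n", "    # lr(t) = lr_0 * ((1 - alpha) * 0.5 * (1 + cos(pi * t / T)) + alpha), for 0 <= t <= T\n", "    return lr0 * ((1 - alpha) * 0.5 * (1 + math.cos(math.pi * t / T)) + alpha)\n", "\n", "for t in [0, 500, 1000, 1500, 2000]:\n", "    print(f'step {t:4d}: optax = {float(cos_schedule(t)):.4e}, formula = {cosine_by_hand(t):.4e}')" 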
] }, { "cell_type": "code", "execution_count": 6, "id": "5e83c9a0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch [500/2000], Loss: 0.072498\n", "Epoch [1000/2000], Loss: 0.022263\n", "Epoch [1500/2000], Loss: 0.015320\n", "Epoch [2000/2000], Loss: 0.014358\n", "\n", "Final Test Loss: 0.021979\n" ] } ], "source": [ "config_cosine = {\n", " 'learning_rate': 1e-3,\n", " 'schedule_type': 'cosine',\n", " 'decay_steps': 2000\n", "}\n", "\n", "train_losses_cosine, test_loss_cosine = run_adam_experiment(config_cosine)" ] }, { "cell_type": "code", "execution_count": null, "id": "1403a6f6", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "0e6da679", "metadata": {}, "source": [ "### Warmup Strategy" ] }, { "cell_type": "markdown", "id": "45c6648c", "metadata": {}, "source": [ "Warmup is a technique where the learning rate starts at zero and linearly increases to the target learning rate over a specified number of steps. This can help stabilize training in the early stages, especially for complex models or difficult optimization landscapes.\n", "\n", "After the warmup period, the learning rate follows the specified schedule (e.g., exponential decay)." 
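, "\n", "\n", "The cell below sketches one way to build such a schedule from `optax.linear_schedule`, `optax.exponential_decay` and `optax.join_schedules`, using the settings of the next experiment. How `get_adam` actually combines warmup and decay is an assumption here; in particular, this sketch restarts the decay clock at the end of the warmup." ] }, { "cell_type": "code", "execution_count": null, "id": "sketch-warmup", "metadata": {}, "outputs": [], "source": [ "import optax\n", "\n", "# Sketch: linear warmup from 0 to lr_0 over warmup_steps, then exponential decay.\n", "# Assumption: get_adam combines warmup and decay in a similar way; the jaxKAN source may differ.\n", "lr0, warmup_steps, T, gamma = 1e-3, 500, 1000, 0.9\n", "warmup = optax.linear_schedule(init_value=0.0, end_value=lr0, transition_steps=warmup_steps)\n", "decay = optax.exponential_decay(init_value=lr0, transition_steps=T, decay_rate=gamma)\n", "# For step >= warmup_steps, join_schedules evaluates decay(step - warmup_steps).\n", "warmup_schedule = optax.join_schedules(schedules=[warmup, decay], boundaries=[warmup_steps])\n", "\n", "for t in [0, 250, 500, 1500, 2000]:\n", "    print(f'step {t:4d}: lr = {float(warmup_schedule(t)):.4e}')" 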
] }, { "cell_type": "code", "execution_count": 7, "id": "7aafc6df", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch [500/2000], Loss: 0.432106\n", "Epoch [1000/2000], Loss: 0.046055\n", "Epoch [1500/2000], Loss: 0.018493\n", "Epoch [2000/2000], Loss: 0.008937\n", "\n", "Final Test Loss: 0.015691\n" ] } ], "source": [ "config_warmup = {\n", " 'learning_rate': 1e-3,\n", " 'schedule_type': 'exponential',\n", " 'decay_steps': 1000,\n", " 'decay_rate': 0.9,\n", " 'warmup_steps': 500\n", "}\n", "\n", "train_losses_warmup, test_loss_warmup = run_adam_experiment(config_warmup)" ] }, { "cell_type": "code", "execution_count": null, "id": "aaf67b1d", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "2afa40c5", "metadata": {}, "source": [ "## Part 2: L-BFGS Optimizer" ] }, { "cell_type": "markdown", "id": "a00d0f95", "metadata": {}, "source": [ "L-BFGS (Limited-memory Broyden-Fletcher-Goldfarb-Shanno) is a quasi-Newton optimization method that can converge faster than first-order methods like Adam on smooth optimization problems. However, it requires special handling in Flax NNX because of its line search.\n", "\n", "**Key differences from Adam:**\n", "\n", "1. L-BFGS uses a line search to choose a suitable step size along each search direction, instead of a fixed learning rate\n", "2. The optimizer's `update()` method requires `value` and `value_fn` arguments\n", "3. Flax NNX automatically handles the `split/merge` operations needed for the line search" ] }, { "cell_type": "markdown", "id": "32fb8201", "metadata": {}, "source": [ "### Experiment Setup for L-BFGS" ] }, { "cell_type": "markdown", "id": "55b8ee5f", "metadata": {}, "source": [ "The training loop for L-BFGS differs from Adam in how we call the `update()` method: we must provide the current loss value and a function that evaluates the loss, which the line search calls at trial points." 
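, "\n", "\n", "To see why the line search needs a loss function, the cell below sketches the same pattern in plain Optax, without Flax NNX, on a hypothetical toy quadratic. It follows the usage documented for `optax.lbfgs` together with `optax.value_and_grad_from_state`; it is not taken from `get_lbfgs`, whose internals are not shown in this tutorial." ] }, { "cell_type": "code", "execution_count": null, "id": "sketch-lbfgs-optax", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import optax\n", "\n", "# Sketch (plain Optax, no Flax NNX): L-BFGS on a hypothetical toy quadratic.\n", "# The zoom line search inside optax.lbfgs evaluates the objective at trial step sizes,\n", "# which is why update() needs value_fn in addition to the gradient.\n", "target = jnp.array([1.0, -2.0])\n", "weights = jnp.array([1.0, 10.0])\n", "\n", "def toy_loss(params):\n", "    return jnp.sum(weights * (params - target) ** 2)\n", "\n", "solver = optax.lbfgs(memory_size=10)\n", "# Reuses the value and gradient stored in the optimizer state by the previous line search.\n", "value_and_grad = optax.value_and_grad_from_state(toy_loss)\n", "\n", "@jax.jit\n", "def lbfgs_step(params, opt_state):\n", "    value, grad = value_and_grad(params, state=opt_state)\n", "    updates, opt_state = solver.update(\n", "        grad, opt_state, params, value=value, grad=grad, value_fn=toy_loss\n", "    )\n", "    return optax.apply_updates(params, updates), opt_state\n", "\n", "params = jnp.zeros(2)\n", "opt_state = solver.init(params)\n", "for step in range(50):\n", "    params, opt_state = lbfgs_step(params, opt_state)\n", "    # Stop once the gradient at the new iterate (kept in the state) is small.\n", "    grad_norm = optax.tree_utils.tree_l2_norm(optax.tree_utils.tree_get(opt_state, 'grad'))\n", "    if grad_norm < 1e-4:\n", "        break\n", "\n", "print(f'iterations: {step + 1}, params: {params}, loss: {float(toy_loss(params)):.2e}')" 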
] }, { "cell_type": "code", "execution_count": 8, "id": "774bb847", "metadata": {}, "outputs": [], "source": [ "def run_lbfgs_experiment(lbfgs_config, num_epochs=500, verbose=True):\n", "    \"\"\"\n", "    Run a training experiment with L-BFGS optimizer.\n", "\n", "    L-BFGS requires special handling: the update method needs both\n", "    the current loss value and a value_fn to evaluate loss at different points.\n", "\n", "    Args:\n", "        lbfgs_config: Dictionary with L-BFGS optimizer parameters\n", "        num_epochs: Number of training epochs\n", "        verbose: Whether to print progress\n", "\n", "    Returns:\n", "        train_losses: Array of training losses\n", "        test_loss: Final test loss\n", "    \"\"\"\n", "    # Initialize a KAN model\n", "    n_in = X_train.shape[1]\n", "    n_out = y_train.shape[1]\n", "    n_hidden = 6\n", "\n", "    layer_dims = [n_in, n_hidden, n_hidden, n_out]\n", "    req_params = {'D': 5, 'flavor': 'exact'}\n", "\n", "    model = KAN(layer_dims=layer_dims,\n", "                layer_type='chebyshev',\n", "                required_parameters=req_params,\n", "                seed=42)\n", "\n", "    # Get L-BFGS optimizer using the utility function\n", "    opt_type = get_lbfgs(**lbfgs_config)\n", "    optimizer = nnx.Optimizer(model, opt_type, wrt=nnx.Param)\n", "\n", "    # Define loss function that takes the model\n", "    def loss_fn(model):\n", "        residual = model(X_train) - y_train\n", "        loss = jnp.mean((residual)**2)\n", "        return loss\n", "\n", "    # Define train step for L-BFGS\n", "    @nnx.jit\n", "    def train_step_lbfgs(model, optimizer):\n", "        # Compute loss and gradients\n", "        loss, grads = nnx.value_and_grad(loss_fn)(model)\n", "\n", "        # Update with L-BFGS\n", "        # IMPORTANT: Must pass value and value_fn for line search\n", "        # Flax NNX automatically handles split/merge internally\n", "        optimizer.update(model, grads, value=loss, value_fn=loss_fn)\n", "\n", "        return loss\n", "\n", "    # Training loop\n", "    train_losses = jnp.zeros((num_epochs,))\n", "\n", "    for epoch in range(num_epochs):\n", "        loss = train_step_lbfgs(model, optimizer)\n", "        train_losses = train_losses.at[epoch].set(loss)\n", "\n", "        if verbose and (epoch + 1) % 100 == 0:\n", "            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.6f}')\n", "\n", "    # Evaluate on test set\n", "    y_pred = model(X_test)\n", "    test_loss = jnp.mean((y_pred - y_test)**2)\n", "\n", "    if verbose:\n", "        print(f'\\nFinal Test Loss: {test_loss:.6f}')\n", "\n", "    return train_losses, test_loss" ] }, { "cell_type": "code", "execution_count": null, "id": "71d97e1c", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "b53fd159", "metadata": {}, "source": [ "### L-BFGS with Default Parameters" ] }, { "cell_type": "markdown", "id": "b859ce9d", "metadata": {}, "source": [ "Let's train with L-BFGS using a memory size of 10 and otherwise default parameters. L-BFGS typically converges in far fewer iterations than Adam, although each iteration is more expensive because the line search may evaluate the loss several times." ] }, { "cell_type": "code", "execution_count": 9, "id": "dc8a9e52", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch [100/100], Loss: 0.000343\n", "\n", "Final Test Loss: 0.000483\n" ] } ], "source": [ "config_lbfgs = {\n", " 'memory_size': 10\n", "}\n", "\n", "train_losses_lbfgs, test_loss_lbfgs = run_lbfgs_experiment(config_lbfgs, num_epochs=100)" ] }, { "cell_type": "code", "execution_count": null, "id": "9ae66d95", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "8e6ae90c", "metadata": {}, "source": [ "### L-BFGS with Larger Memory" ] }, { "cell_type": "markdown", "id": "157c8d2d", "metadata": {}, "source": [ "Increasing the memory size lets L-BFGS keep more of its most recent curvature pairs (parameter changes and the corresponding gradient changes), which can improve its approximation of the inverse Hessian at the cost of extra memory and computation per iteration." 
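, "\n", "\n", "A minimal sketch of the standard L-BFGS two-loop recursion shows what the memory size controls: the number of recent curvature pairs $(s_i, y_i)$, with $s_i = x_{i+1} - x_i$ and $y_i = \\nabla f(x_{i+1}) - \\nabla f(x_i)$, used to turn the gradient into a search direction. The function `two_loop_direction` and the toy quadratic are hypothetical, and Optax's implementation may differ in details such as the initial scaling. On this quadratic, using both pairs should reproduce $A^{-1} g$ (the exact Newton step, up to sign), while keeping only the newest pair should not." ] }, { "cell_type": "code", "execution_count": null, "id": "sketch-two-loop", "metadata": {}, "outputs": [], "source": [ "import jax.numpy as jnp\n", "\n", "def two_loop_direction(grad, s_list, y_list):\n", "    # Standard L-BFGS two-loop recursion (hypothetical helper, not part of jaxKAN or Optax).\n", "    # s_list[i] = x_{i+1} - x_i and y_list[i] = g_{i+1} - g_i, ordered oldest to newest;\n", "    # the memory size is the number of (s, y) pairs kept in these lists.\n", "    # Returns r, an approximation of H^{-1} @ grad; the search direction is -r.\n", "    q = grad\n", "    alphas = []\n", "    for s, y in zip(reversed(s_list), reversed(y_list)):  # newest pair first\n", "        rho = 1.0 / jnp.vdot(y, s)\n", "        alpha = rho * jnp.vdot(s, q)\n", "        q = q - alpha * y\n", "        alphas.append(alpha)\n", "    # Initial inverse-Hessian guess gamma * I, scaled using the newest pair\n", "    s_new, y_new = s_list[-1], y_list[-1]\n", "    r = (jnp.vdot(s_new, y_new) / jnp.vdot(y_new, y_new)) * q\n", "    for s, y, alpha in zip(s_list, y_list, reversed(alphas)):  # oldest pair first\n", "        rho = 1.0 / jnp.vdot(y, s)\n", "        beta = rho * jnp.vdot(y, r)\n", "        r = r + s * (alpha - beta)\n", "    return r\n", "\n", "# Hypothetical quadratic f(x) = 0.5 * x^T A x, for which gradient differences satisfy y = A s\n", "A = jnp.diag(jnp.array([2.0, 4.0]))\n", "s_pairs = [jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])]\n", "y_pairs = [A @ s for s in s_pairs]\n", "g = jnp.array([1.0, 1.0])\n", "\n", "print('memory 1:', two_loop_direction(g, s_pairs[-1:], y_pairs[-1:]))\n", "print('memory 2:', two_loop_direction(g, s_pairs, y_pairs))\n", "print('Newton  :', jnp.linalg.solve(A, g))" 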
] }, { "cell_type": "code", "execution_count": 10, "id": "2bb07573", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch [100/100], Loss: 0.000259\n", "\n", "Final Test Loss: 0.000384\n" ] } ], "source": [ "config_lbfgs_large = {\n", " 'memory_size': 20\n", "}\n", "\n", "train_losses_lbfgs_large, test_loss_lbfgs_large = run_lbfgs_experiment(config_lbfgs_large, num_epochs=100)" ] }, { "cell_type": "code", "execution_count": null, "id": "84f5d70a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }